!! Copyright (C) 2009,2010,2011,2012  Marco Restelli
!!
!! This file is part of:
!!   FEMilaro -- Finite Element Method toolkit
!!
!! FEMilaro is free software; you can redistribute it and/or modify it
!! under the terms of the GNU General Public License as published by
!! the Free Software Foundation; either version 3 of the License, or
!! (at your option) any later version.
!!
!! FEMilaro is distributed in the hope that it will be useful, but
!! WITHOUT ANY WARRANTY; without even the implied warranty of
!! MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
!! General Public License for more details.
!!
!! You should have received a copy of the GNU General Public License
!! along with FEMilaro; If not, see <http://www.gnu.org/licenses/>.
!!
!! author: Marco Restelli                   <marco.restelli@gmail.com>

!> \file
!! Simple ODE program to test the numerical integrators.
!!
!! \n
!!
!! This example has two purposes: set a common structure for the time
!! integrators and provide a specific example to test them.
!!
!! The building block for any time integrator is the type used to
!! represent the state of the solution: here this is the abstract type
!! \c c_stv (from \c mod_state_vars), extended in this file as
!! \c t_stv. The time integrators are described in detail in
!! \c mod_time_integrators.
!<
module ode_test_problem

 use mod_messages, only: &
   error,   &
   warning, &
   info

 use mod_kinds, only: &
   wp

 use mod_state_vars, only: &
   c_stv

 use mod_time_integrators, only: &
   c_ode, c_ods

 implicit none

 public :: &
   t_stv, t_ode, c_ods, ode_exact

 private

 !> State vector of the test problem.
 !!
 !! The state has two fields:
 !! - \c f1: four decoupled scalar equations, see \c ode_rhs;
 !! - \c f2: a two-component linear rotation. It is indexed as a
 !!   2-vector by \c ode_rhs, \c ode_solve and \c ode_exact, so the
 !!   caller must allocate it with (at least) two entries.
 type, extends(c_stv) :: t_stv
  real(wp) :: f1(4)
  real(wp), allocatable :: f2(:)
 contains
  procedure, pass(x) :: incr
  procedure, pass(x) :: tims
  procedure, pass(z) :: copy
  procedure, pass(x) :: scal
  procedure, pass(x) :: show
  procedure, pass(x) :: source
  procedure, pass(x) :: source_vect
 end type t_stv

 !> ODE problem: provides the right-hand side and the implicit solver.
 type, extends(c_ode) :: t_ode
 contains
  procedure, pass(ode) :: rhs   => ode_rhs
  procedure, pass(ode) :: solve => ode_solve
 end type t_ode

 !> Problem parameters:
 !! - \c nn: exponent of the time forcing in \c f1(4);
 !! - \c lambda: coefficient of the right-hand sides of \c f1(1:3);
 !! - \c tau: angular frequency of the rotation in \c f2;
 !! - \c alpha: amplitude of the time forcing in \c f1(4).
 integer, parameter :: &
   nn     =  1
 real(wp), parameter :: &
   lambda = -3.0_wp, &
   tau    =  7.3_wp, &
   alpha  =  0.1_wp

contains

 !> Evaluate the tendency \f$ f(t,u) \f$ of the test problem:
 !! \f{eqnarray*}{
 !!  f_1 & = & \lambda u_1, \qquad f_2 = \lambda u_2^2, \qquad
 !!   f_3 = \lambda\, {\rm sign}(u_3)\sqrt{|u_3|}, \qquad
 !!   f_4 = \alpha t^{nn}, \\
 !!  f_{2,1} & = & -\tau u_{2,2}, \qquad f_{2,2} = \tau u_{2,1}.
 !! \f}
 !!
 !! If \c term is present, only one part of the right-hand side is
 !! evaluated and the other is set to zero: <tt>term(1)==1</tt> selects
 !! \c f2 (the linear rotation), any other value selects \c f1;
 !! <tt>term(2)</tt> is not used here. Presumably this splitting is
 !! used by the exponential/partitioned integrators -- TODO confirm.
 subroutine ode_rhs(tnd,ode,t,uuu,ods,term)
  class(t_ode), intent(in)    :: ode !< ODE problem
  real(wp),     intent(in)    :: t   !< time level
  class(c_stv), intent(in)    :: uuu !< present state
  class(c_ods), intent(inout) :: ods !< scratch (diagnostic vars.)
  class(c_stv), intent(inout) :: tnd !< tendency
  integer,      intent(in), optional :: term(2)

  logical :: compute_f1, compute_f2

   select type(uuu); type is(t_stv); select type(tnd); type is(t_stv)

   compute_f1= .true.; compute_f2 = .true.
   if(present(term)) then
     if(term(1).eq.1) then
       compute_f1= .false.
     else
       compute_f2= .false.
     endif
   endif

   if(compute_f1) then
     tnd%f1(1) = lambda*uuu%f1(1)
     tnd%f1(2) = lambda*uuu%f1(2)**2
     tnd%f1(3) = lambda*sign(1.0_wp,uuu%f1(3))*sqrt(abs(uuu%f1(3)))
     tnd%f1(4) = alpha * (t**nn)
   else
     tnd%f1 = 0.0_wp
   endif

   if(compute_f2) then
     tnd%f2(1) = -tau*uuu%f2(2)
     tnd%f2(2) =  tau*uuu%f2(1)
   else
     tnd%f2 = 0.0_wp
   endif

   end select; end select
 end subroutine ode_rhs

 !> In-place increment: \f$ x \leftarrow x + y \f$.
 subroutine incr(x,y)
  class(c_stv), intent(in)    :: y
  class(t_stv), intent(inout) :: x

   select type(y); type is(t_stv)

   x%f1 = x%f1 + y%f1
   x%f2 = x%f2 + y%f2

   end select
 end subroutine incr

 !> In-place scaling: \f$ x \leftarrow r\,x \f$.
 subroutine tims(x,r)
  real(wp),     intent(in)    :: r
  class(t_stv), intent(inout) :: x

   x%f1 = r*x%f1
   x%f2 = r*x%f2

 end subroutine tims

 !> Copy: \f$ z \leftarrow x \f$. The \c f2 component of \c z is
 !! reallocated on assignment if needed.
 subroutine copy(z,x)
  class(c_stv), intent(in)    :: x
  class(t_stv), intent(inout) :: z

   select type(x); type is(t_stv)
   z%f1 = x%f1
   z%f2 = x%f2
   end select
 end subroutine copy

 !> Euclidean scalar product \f$ s = x\cdot y \f$ over both \c f1 and
 !! \c f2. If \c y is not a \c t_stv the result is zero.
 function scal(x,y) result(s)
  class(t_stv), intent(in) :: x
  class(c_stv), intent(in) :: y
  real(wp) :: s

   ! Define the result on every path: without this, a non-matching
   ! dynamic type of y would leave the function result undefined.
   s = 0.0_wp
   select type(y); type is(t_stv)
    s = dot_product(x%f1,y%f1) + dot_product(x%f2,y%f2)
   end select
 end function scal

 !> Print the state on standard output.
 subroutine show(x)
  class(t_stv), intent(in) :: x
   write(*,*) "f1: ", x%f1, "; f2: ", x%f2
 end subroutine show

 !> Allocate \c y as a \c t_stv whose \c f2 has the same size as that
 !! of \c x. Only the storage is set up: no values are copied.
 subroutine source(y,x)
  class(t_stv), intent(in)  :: x
  class(c_stv), allocatable, intent(out) :: y

   allocate(t_stv::y)

   select type(y); type is(t_stv)

   allocate(y%f2(size(x%f2)))

   end select
 end subroutine source
 !> Allocate an array of \c m \c t_stv objects, each with \c f2 sized
 !! like that of \c x. Only the storage is set up: no values are copied.
 subroutine source_vect(y,x,m)
  integer, intent(in) :: m
  class(t_stv), intent(in)  :: x
  class(c_stv), allocatable, intent(out) :: y(:)

  integer :: i

   allocate(t_stv::y(m))

   select type(y); type is(t_stv)

   do i=1,m
     allocate(y(i)%f2(size(x%f2)))
   enddo

   end select
 end subroutine source_vect

 !> Solve the (linearized) implicit problem
 !! \f$ x - \sigma f(t,x) = b \f$, with nonlinear terms linearized
 !! about \c xl:
 !! - \c f1(1), \c f1(4) and \c f2 are solved exactly (they are linear
 !!   in \c x or independent of it; \c f2 uses the explicit inverse of
 !!   the 2x2 matrix \f$ I-\sigma A \f$);
 !! - \c f1(2) uses the Picard linearization
 !!   \f$ \lambda x^2 \approx \lambda\, x_l\, x \f$;
 !! - \c f1(3) uses the Newton linearization of
 !!   \f$ \lambda\,{\rm sign}(x)\sqrt{|x|} \f$ about \f$ x_l \f$.
 subroutine ode_solve(x,ode,t,sigma,b,xl,ods)
  class(t_ode), intent(in)    :: ode
  real(wp),     intent(in) :: t, sigma
  class(c_stv), intent(in) :: b, xl
  class(c_stv), intent(inout) :: x
  class(c_ods), intent(inout) :: ods

  real(wp) :: v1, v2

   select type(x); type is(t_stv)
   select type(b); type is(t_stv); select type(xl); type is(t_stv)

   x%f1(1) = 1.0_wp/(1.0_wp-lambda*sigma) * b%f1(1)
   x%f1(2) = 1.0_wp/(1.0_wp-lambda*sigma*xl%f1(2)) * b%f1(2)
   ! Newton step: with v1=sqrt|xl| and v2=lambda*sigma/2 the linearized
   ! equation reduces to x*(v1-v2)/v1 = b + v2*sign(xl)*v1. Since
   ! lambda<0, v1-v2>0 unless both xl(3) and sigma vanish.
   v1 = sqrt(abs(xl%f1(3)))
   v2 = lambda*sigma/2.0_wp
   x%f1(3) = v1/(v1-v2)*( b%f1(3) + v2*sign(1.0_wp,xl%f1(3))*v1 )
   x%f1(4) = b%f1(4) + sigma*alpha * (t**nn)

   x%f2(1) = b%f2(1) - sigma*tau*b%f2(2)
   x%f2(2) = sigma*tau*b%f2(1) + b%f2(2)
   x%f2 = x%f2/(1.0_wp+(sigma*tau)**2)

   end select; end select
   end select
 end subroutine ode_solve

 !> Exact solution at time \c t, corresponding to the initial
 !! condition \c f1=(3,2,1,0.01), \c f2=(1,0). The \c f2 component of
 !! \c y must already be allocated with two entries.
 pure subroutine ode_exact(t,y)
  real(wp),     intent(in) :: t
  class(t_stv), intent(inout) :: y

   y%f1(1) = 3.0_wp*exp(lambda*t)
   y%f1(2) = 2.0_wp/(1.0_wp-lambda*2.0_wp*t)
   ! the square-root equation reaches zero in finite time and stays there
   y%f1(3) = max(lambda*t/2.0_wp+1.0_wp,0.0_wp)**2
   y%f1(4) = 0.01_wp + alpha/real(nn+1,wp) * (t**(nn+1))

   y%f2(1) = cos(tau*t)
   y%f2(2) = sin(tau*t)

 end subroutine ode_exact

end module ode_test_problem


program ode_test

 use mod_messages, only: &
   mod_messages_constructor, &
   mod_messages_destructor,  &
   error,   &
   warning, &
   info

 use mod_kinds, only: &
   mod_kinds_constructor, &
   mod_kinds_destructor,  &
   wp

 use mod_mpi_utils, only: &
   mod_mpi_utils_constructor, &
   mod_mpi_utils_destructor,  &
   mpi_init, mpi_finalize

 use mod_linal, only: &
   mod_linal_constructor, &
   mod_linal_destructor

 use mod_state_vars, only: &
   mod_state_vars_constructor, &
   mod_state_vars_destructor,  &
   c_stv

 use mod_time_integrators, only: &
   mod_time_integrators_constructor, &
   mod_time_integrators_destructor,  &
   c_ode, c_ods, c_tint, t_ti_init_data, t_ti_step_diag, &
   t_ee, t_hm, t_rk4, t_ssprk54, &
   t_thetam, t_bdf1, t_bdf2, t_bdf3, t_bdf2ex, &
   t_erb2kry, t_erb2lej, &
   t_pcexpo

 use ode_test_problem, only: &
   t_stv, t_ode, c_ods, ode_exact

 implicit none

 ! time step and final time of the simulation
 real(wp), parameter :: dt    = 0.05_wp/2.0_wp
 real(wp), parameter :: t_end = 4.0_wp

 integer :: n, n0, i
 integer :: ierr          ! MPI return code
 integer :: ios           ! I/O status
 integer :: u_out, u_diag ! output units (solution, diagnostics)
 real(wp) :: t
 type(t_stv) :: uuu0, uuun, uuuex
 type(t_ode) :: ode
 type(c_ods) :: ods
 type(t_ti_init_data) :: init_data
 type(t_ti_step_diag) :: step_diag
 class(c_tint), allocatable :: integrator

 ! Note: the following commands are useful to check the results in
 ! octave
 !  load ode-test.out
 !  figure;plot(ode_test(:,1),ode_test(:,8:13),'k');hold on;plot(ode_test(:,1),ode_test(:,2:7),'o-');figure;plot(ode_test(:,1),ode_test(:,14:end))
 !  sum(abs(ode_test(:,14:end)),1)*4/size(ode_test,1)

  call mod_messages_constructor()
  call mod_kinds_constructor()
  call mpi_init(ierr)
  call mod_mpi_utils_constructor()
  call mod_linal_constructor()
  call mod_state_vars_constructor()

  call mod_time_integrators_constructor()

  ! the f2 component of every state holds two entries
  allocate(uuu0%f2(2),uuun%f2(2),uuuex%f2(2))
  call ode_exact(0.0_wp,uuu0) ! set the initial condition

  ! output files: solution history and integrator diagnostics
  open(newunit=u_out, file="ode-test.out", status='replace', &
       action='write', form='formatted', position='rewind', iostat=ios)
  if(ios.ne.0) error stop "ode_test: cannot open ode-test.out"
  open(newunit=u_diag, file="ode-test-diag.out", status='replace', &
       action='write', form='formatted', position='rewind', iostat=ios)
  if(ios.ne.0) error stop "ode_test: cannot open ode-test-diag.out"

  ! Select the integrator by uncommenting one of the following groups.

  !allocate(t_ee::integrator)

  !allocate(t_hm::integrator)

  !allocate(t_rk4::integrator)

  !allocate(t_ssprk54::integrator)

  !allocate(t_thetam::integrator)
  !allocate(init_data%ir1(1)); init_data%ir1 = (/0.6_wp/) ! set theta

  !allocate(t_bdf1::integrator)

  !allocate(t_bdf2::integrator)

  !allocate(t_bdf3::integrator)

  !allocate(t_bdf2ex::integrator)

  !allocate(t_erb2kry::integrator)
  !init_data%dim = 25        ! size of the Krylov space
  !init_data%tol = 1.0e-9_wp ! tolerance

  !allocate(t_erb2lej::integrator)
  !init_data%dim = 25        ! maximum number of Leja points
  !init_data%tol = 1.0e-9_wp ! tolerance
  !allocate(init_data%ir1(2)); init_data%ir1 = (/-20.0_wp,0.1_wp/)

  allocate(t_pcexpo::integrator)
  init_data%dim = 25
  init_data%tol = 1.0e-9_wp
  allocate(init_data%ii1(1)); init_data%ii1 = (/2/) ! 1,2 -> kry,lej
  allocate(init_data%ir1(2)); init_data%ir1 = (/-20.0_wp,0.1_wp/)

  call integrator%init( ode,dt,0.0_wp,uuu0,ods,          &
                init_data=init_data , step_diag=step_diag)
  n0 = step_diag%bootstrap_steps
  do n=n0,nint(t_end/dt)-1
    t = real(n,wp)*dt

    call integrator%step(uuun,ode,t,uuu0,ods,step_diag)
    uuu0 = uuun

    ! columns: time, numerical solution (6), exact solution (6), error (6)
    call ode_exact(t+dt,uuuex)
    write(u_out,'(19e14.6)') &
      t+dt,               &
      uuu0%f1,  uuu0%f2,  &
      uuuex%f1, uuuex%f2, &
      uuuex%f1 - uuu0%f1, uuuex%f2 - uuu0%f2

    if(allocated(step_diag%d1)) then
      write(u_diag,'(a,e14.6)') 'Time: ',t+dt
      ! A note about the edit descriptor:
      ! t2    -> go to position 2
      ! a,l3  -> write a string and a logical
      ! /     -> start a new line
      ! 1x,i0 -> one blank, then an integer with the minimal width
      write(u_diag,'(t2,a,l3,/,t2,a,1x,i0)') &
       'max_iter:', step_diag%max_iter, &
       'iterations:', step_diag%iterations
      ! concerning the edit descriptor:
      ! t4    -> go to position 4
      ! *     -> unlimited format item (f2008)
      ! :     -> stop writing as soon as there are no more values
      !          (otherwise " , " would be printed at the end of the
      !          line, before terminating the output). This works like
      !          an "exit" in the edit descriptor, which exits as soon
      !          as the last item has been written
      ! " , " -> insert space, comma, space
      write(u_diag,'(t2,a,/,t4,*(e13.6,:," , "))') &
       'residuals:', step_diag%d1
      write(u_diag,'(t2,a)') 'h:'
      ! NOTE(review): assumes step_diag%d2 is allocated whenever d1 is
      ! -- confirm against mod_time_integrators.
      do i=1,size(step_diag%d2,1)
        write(u_diag,'(" ",*(e13.6,:," , "))') &
         step_diag%d2(i,:)
      enddo
      write(u_diag,*) ''
    endif
  enddo

  call integrator%clean(ode,step_diag)
  deallocate(integrator)
  if(allocated(init_data%ii1)) deallocate(init_data%ii1)
  if(allocated(init_data%ir1)) deallocate(init_data%ir1)
  deallocate(uuu0%f2,uuun%f2,uuuex%f2)

  close(u_out)
  close(u_diag)

  call mod_time_integrators_destructor()

  call mod_state_vars_destructor()
  call mod_linal_destructor()
  call mod_mpi_utils_destructor()
  call mpi_finalize(ierr)
  call mod_kinds_destructor()
  call mod_messages_destructor()

end program ode_test
